{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Getting Started with Cloud TPUs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Before We Start\n",
    "\n",
    "This notebook requires that you run it on a GCE VM in a GCP project that has Cloud TPU quota. If you are not using this notebook on a pre-built virtual machine image, here is how you can start a new GCE VM with the right settings:\n",
    "\n",
    "1. Create a new GCE VM with the following settings:\n",
    "  1. Zone: us-central1-c or us-central1-f\n",
    "  2. Machine Type: n1-standard-8\n",
    "  3. Operating System: Ubuntu 16.04 LTS with 64 GB of persistent disk\n",
    "  4. Identity and API access: Enable \"Allow full access to all Cloud APIs\"\n",
    "  5. Networking > Network Tag: `tpu-jupyterhub-demo`\n",
    "2. Create a firewall rule under VPC network > Firewall rules in the Cloud Console with the following settings:\n",
    "  1. Name: `tpu-jupyterhub-demo`\n",
    "  2. Target tags: `tpu-jupyterhub-demo`\n",
    "  3. Source IP ranges: `0.0.0.0/0`\n",
    "  4. Protocols and ports: Specified protocols and ports, `tcp:6006,8888`\n",
    "3. SSH into the GCE VM you have created in Step 1 and run the following:\n",
    "  1. `sudo apt-get update`\n",
    "  2. `sudo apt-get -y install python3 python3-pip`\n",
    "  3. `sudo -H pip3 install jupyter tf-nightly google-api-python-client`\n",
    "4. Start Jupyter in the GCE VM with `jupyter notebook --no-browser --ip=0.0.0.0`\n",
    "5. Navigate to http://IP.OF.MY.VM:8888/?token=THE.TOKEN.DISPLAYED.ON.THE.COMMANDLINE in your favorite browser.\n",
    "6. Upload this notebook to Jupyter by clicking the `Upload` button."
   ]
  },
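  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The setup in steps 1 and 2 above can also be scripted. The following is a rough sketch of equivalent `gcloud` commands, run from your local machine or Cloud Shell rather than from this notebook; the VM name `tpu-demo-vm` is only an example, and you may need to adjust flags for your project:\n",
    "\n",
    "```bash\n",
    "# Step 1: create the GCE VM (Ubuntu 16.04 LTS, 64 GB disk, full API access)\n",
    "gcloud compute instances create tpu-demo-vm \\\n",
    "    --zone=us-central1-f \\\n",
    "    --machine-type=n1-standard-8 \\\n",
    "    --image-family=ubuntu-1604-lts --image-project=ubuntu-os-cloud \\\n",
    "    --boot-disk-size=64GB \\\n",
    "    --scopes=cloud-platform \\\n",
    "    --tags=tpu-jupyterhub-demo\n",
    "\n",
    "# Step 2: open TCP ports 6006 (TensorBoard) and 8888 (Jupyter)\n",
    "gcloud compute firewall-rules create tpu-jupyterhub-demo \\\n",
    "    --allow=tcp:6006,tcp:8888 \\\n",
    "    --source-ranges=0.0.0.0/0 \\\n",
    "    --target-tags=tpu-jupyterhub-demo\n",
    "```"
   ]
  },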
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration\n",
    "\n",
    "Modify the following environment variables as required for your setup.\n",
    "\n",
    "* `GCE_PROJECT_NAME`: The name of the GCE project this VM (and your Cloud TPU) starts in.\n",
    "* `TPU_ZONE`: The GCE zone in which your Cloud TPU will start.\n",
    "* `TPU_NAME`: The name of the Cloud TPU.\n",
    "* `TPU_IP_RANGE`: The IP address range for the Cloud TPU.\n",
    "* `GCS_DATA_PATH`: The GCS path where we will store sample test data for Cloud TPUs. Because GCS bucket names are globally unique, you may need to change this.\n",
    "* `GCS_CKPT_PATH`: The GCS path where we will store sample checkpoint data for Cloud TPUs. Because GCS bucket names are globally unique, you may need to change this.\n",
    "\n",
    "**Note:** We will grant _Storage Admin_ permissions to Cloud TPUs on the GCS buckets specified in `GCS_DATA_PATH` and `GCS_CKPT_PATH`. We encourage you to specify new buckets or test buckets containing non-production data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%env GCE_PROJECT_NAME my-sample-tpu-project\n",
    "%env TPU_ZONE us-central1-f\n",
    "%env TPU_NAME demo-tpu\n",
    "%env TPU_IP_RANGE 10.240.1.0/29\n",
    "%env GCS_DATA_PATH gs://cloud-tpu-data-bucket/mnist/\n",
    "%env GCS_CKPT_PATH gs://cloud-tpu-checkpoint-bucket/mnist/\n",
    "\n",
    "# Automatically get bucket name from GCS paths\n",
    "import os\n",
    "os.environ['GCS_DATA_BUCKET'] = os.environ['GCS_DATA_PATH'][5:].split('/')[0]\n",
    "os.environ['GCS_CKPT_BUCKET'] = os.environ['GCS_CKPT_PATH'][5:].split('/')[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a new Cloud TPU\n",
    "\n",
    "You can create a new Cloud TPU by running the commands in the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gcloud config set compute/zone $TPU_ZONE\n",
    "!gcloud alpha compute tpus create $TPU_NAME --range=$TPU_IP_RANGE --accelerator-type=tpu-v2 --version=nightly --zone=$TPU_ZONE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create GCS Buckets for Cloud TPU\n",
    "\n",
    "Here, we will create two GCS buckets -- one for training/test data (`GCS_DATA_BUCKET`), and the other for TensorFlow checkpoint and TensorBoard metric data (`GCS_CKPT_BUCKET`).\n",
    "\n",
    "The first two commands create the buckets in a single region for maximum performance, and the final command grants the Cloud TPU service account admin access to the buckets so that it can read from and write to them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gsutil mb -c regional -l us-central1 gs://$GCS_DATA_BUCKET\n",
    "!gsutil mb -c regional -l us-central1 gs://$GCS_CKPT_BUCKET\n",
    "\n",
    "!gsutil iam ch serviceAccount:`gcloud alpha compute tpus describe $TPU_NAME | grep serviceAccount | cut -d' ' -f2`:admin gs://$GCS_DATA_BUCKET gs://$GCS_CKPT_BUCKET && echo 'Successfully set permissions!'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connect to Cloud TPU and Run a Simple *AX+Y* Calculation\n",
    "\n",
    "The following code resolves the IP address of your Cloud TPU using the `TPUClusterResolver`, connects to the Cloud TPU with `tf.Session`, and runs a simple calculation on the TPU.\n",
    "\n",
    "If the run is successful, you should see a 3x3 array of random numbers printed out. This computation may take up to 60 seconds to complete due to TPU initialization overhead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from tensorflow.contrib import tpu\n",
    "from tensorflow.contrib.cluster_resolver import TPUClusterResolver\n",
    "\n",
    "def axy_computation(a, x, y):\n",
    "  return a * x + y\n",
    "\n",
    "inputs = [\n",
    "    3.0,\n",
    "    tf.random_uniform([3, 3], 0, 1, tf.float32),\n",
    "    tf.random_uniform([3, 3], 0, 1, tf.float32),\n",
    "]\n",
    "tpu_computation = tpu.rewrite(axy_computation, inputs)\n",
    "\n",
    "tpu_cluster_resolver = TPUClusterResolver([os.environ['TPU_NAME']], zone=os.environ['TPU_ZONE'], project=os.environ['GCE_PROJECT_NAME'])\n",
    "tpu_grpc_url = tpu_cluster_resolver.get_master()\n",
    "\n",
    "with tf.Session(tpu_grpc_url) as sess:\n",
    "  print('Initializing TPU...')\n",
    "  sess.run(tpu.initialize_system())\n",
    "  print('Initializing global variables...')\n",
    "  sess.run(tf.global_variables_initializer())\n",
    "  print('Executing TPU operation...')\n",
    "  output = sess.run(tpu_computation)\n",
    "  print(output)\n",
    "  print('Shutting down TPU...')\n",
    "  sess.run(tpu.shutdown_system())\n",
    "  print('Done!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating and Uploading TFRecords from MNIST Test Data\n",
    "\n",
    "The script below downloads the MNIST data from http://yann.lecun.com/exdb/mnist/, creates TFRecord files from it, and writes them to the GCS data bucket (`GCS_DATA_BUCKET`) that we created earlier. If the run is successful, the last line of the output should read `Finished writing TFRecords to` followed by your `GCS_DATA_PATH`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.contrib.learn.python.learn.datasets import mnist\n",
    "\n",
    "\n",
    "def _int64_feature(value):\n",
    "  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))\n",
    "\n",
    "\n",
    "def _bytes_feature(value):\n",
    "  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
    "\n",
    "\n",
    "def convert_mnist_data_to_tfrecord(data_set, name):\n",
    "  \"\"\"Converts a dataset to tfrecords.\"\"\"\n",
    "  images = data_set.images\n",
    "  labels = data_set.labels\n",
    "  num_examples = data_set.num_examples\n",
    "\n",
    "  if images.shape[0] != num_examples:\n",
    "    raise ValueError('Images size %d does not match label size %d.' %\n",
    "                     (images.shape[0], num_examples))\n",
    "  rows = images.shape[1]\n",
    "  cols = images.shape[2]\n",
    "  depth = images.shape[3]\n",
    "\n",
    "  filename = os.path.join(os.environ['GCS_DATA_PATH'], name + '.tfrecords')\n",
    "  print('Writing', filename)\n",
    "  with tf.python_io.TFRecordWriter(filename) as writer:\n",
    "    for index in range(num_examples):\n",
    "      image_raw = images[index].tostring()\n",
    "      example = tf.train.Example(\n",
    "          features=tf.train.Features(\n",
    "              feature={\n",
    "                  'height': _int64_feature(rows),\n",
    "                  'width': _int64_feature(cols),\n",
    "                  'depth': _int64_feature(depth),\n",
    "                  'label': _int64_feature(int(labels[index])),\n",
    "                  'image_raw': _bytes_feature(image_raw)\n",
    "              }))\n",
    "      writer.write(example.SerializeToString())\n",
    "\n",
    "\n",
    "def convert_mnist_to_tfrecord():\n",
    "  # Get the data.\n",
    "  data_sets = mnist.read_data_sets('/tmp',\n",
    "                                   dtype=tf.uint8,\n",
    "                                   reshape=False,\n",
    "                                   validation_size=5000)\n",
    "\n",
    "  # Convert to Examples and write the result to TFRecords.\n",
    "  convert_mnist_data_to_tfrecord(data_sets.train, 'train')\n",
    "  convert_mnist_data_to_tfrecord(data_sets.validation, 'validation')\n",
    "  convert_mnist_data_to_tfrecord(data_sets.test, 'test')\n",
    "\n",
    "convert_mnist_to_tfrecord()\n",
    "print(\"Finished writing TFRecords to %s\" % os.environ['GCS_DATA_PATH'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining a Simple Neural Network Model with TPU Estimators\n",
    "\n",
    "We can define a simple model with TensorFlow Estimators that trains on the MNIST dataset to classify images. As part of using TensorFlow Estimators, we need to create a model function that defines the model (`mnist_model_fn`) and an input function that processes the inputs (the result of `mnist_get_input_fn`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.contrib.tpu.python.tpu import tpu_config\n",
    "from tensorflow.contrib.tpu.python.tpu import tpu_estimator\n",
    "from tensorflow.contrib.tpu.python.tpu import tpu_optimizer\n",
    "\n",
    "def mnist_metric_fn(labels, logits):\n",
    "  \"\"\"Evaluation metric Fn which runs on CPU.\"\"\"\n",
    "  predictions = tf.argmax(logits, 1)\n",
    "  return {\n",
    "      \"accuracy\": tf.metrics.accuracy(\n",
    "          labels=labels, predictions=predictions),\n",
    "  }\n",
    "\n",
    "\n",
    "def mnist_model_fn(features, labels, mode, params):\n",
    "  \"\"\"A simple CNN.\"\"\"\n",
    "  del params\n",
    "\n",
    "  if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "    raise RuntimeError(\"mode {} is not supported yet\".format(mode))\n",
    "\n",
    "  input_layer = tf.reshape(features, [-1, 28, 28, 1])\n",
    "  conv1 = tf.layers.conv2d(\n",
    "      inputs=input_layer,\n",
    "      filters=32,\n",
    "      kernel_size=[5, 5],\n",
    "      padding=\"same\",\n",
    "      activation=tf.nn.relu)\n",
    "  pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)\n",
    "  conv2 = tf.layers.conv2d(\n",
    "      inputs=pool1,\n",
    "      filters=64,\n",
    "      kernel_size=[5, 5],\n",
    "      padding=\"same\",\n",
    "      activation=tf.nn.relu)\n",
    "  pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)\n",
    "  pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])\n",
    "  dense = tf.layers.dense(inputs=pool2_flat, units=128, activation=tf.nn.relu)\n",
    "  dropout = tf.layers.dropout(\n",
    "      inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)\n",
    "  logits = tf.layers.dense(inputs=dropout, units=10)\n",
    "  onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)\n",
    "\n",
    "  loss = tf.losses.softmax_cross_entropy(\n",
    "      onehot_labels=onehot_labels, logits=logits)\n",
    "\n",
    "  if mode == tf.estimator.ModeKeys.EVAL:\n",
    "    return tpu_estimator.TPUEstimatorSpec(\n",
    "        mode=mode,\n",
    "        loss=loss,\n",
    "        eval_metrics=(mnist_metric_fn, [labels, logits]))\n",
    "\n",
    "  # Train.\n",
    "  learning_rate = tf.train.exponential_decay(0.05,\n",
    "                                             tf.train.get_global_step(), 100000,\n",
    "                                             0.96)\n",
    "\n",
    "  optimizer = tpu_optimizer.CrossShardOptimizer(\n",
    "      tf.train.GradientDescentOptimizer(learning_rate=learning_rate))\n",
    "\n",
    "  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())\n",
    "  return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)\n",
    "\n",
    "\n",
    "def mnist_get_input_fn(filename):\n",
    "  \"\"\"Returns an `input_fn` for train and eval.\"\"\"\n",
    "\n",
    "  def input_fn(params):\n",
    "    \"\"\"A simple input_fn using the experimental input pipeline.\"\"\"\n",
    "    # Retrieves the batch size for the current shard. The # of shards is\n",
    "    # computed according to the input pipeline deployment. See\n",
    "    # `tf.contrib.tpu.RunConfig` for details.\n",
    "    batch_size = params[\"batch_size\"]\n",
    "\n",
    "    def parser(serialized_example):\n",
    "      \"\"\"Parses a single tf.Example into image and label tensors.\"\"\"\n",
    "      features = tf.parse_single_example(\n",
    "          serialized_example,\n",
    "          features={\n",
    "              \"image_raw\": tf.FixedLenFeature([], tf.string),\n",
    "              \"label\": tf.FixedLenFeature([], tf.int64),\n",
    "          })\n",
    "      image = tf.decode_raw(features[\"image_raw\"], tf.uint8)\n",
    "      image.set_shape([28 * 28])\n",
    "      # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]\n",
    "      image = tf.cast(image, tf.float32) * (1. / 255) - 0.5\n",
    "      label = tf.cast(features[\"label\"], tf.int32)\n",
    "      return image, label\n",
    "\n",
    "    dataset = tf.data.TFRecordDataset(\n",
    "        filename, buffer_size=None)\n",
    "    dataset = dataset.map(parser).cache().repeat()\n",
    "    dataset = dataset.apply(\n",
    "        tf.contrib.data.batch_and_drop_remainder(batch_size))\n",
    "    images, labels = dataset.make_one_shot_iterator().get_next()\n",
    "    return images, labels\n",
    "  return input_fn\n",
    "\n",
    "print(\"Estimator-based MNIST Model Defined Successfully\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running a TPU Estimator-based Model\n",
    "\n",
    "To run a TPU Estimator-based model, we define a run configuration (`tpu_config.RunConfig`), create a `TPUEstimator` from the `model_fn`, and then call `train` and `evaluate` on the Estimator. This computation may take up to 120 seconds to complete due to TPU initialization overhead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.contrib.cluster_resolver import TPUClusterResolver\n",
    "import os\n",
    "import time\n",
    "\n",
    "tpu_cluster_resolver = TPUClusterResolver([os.environ['TPU_NAME']], zone=os.environ['TPU_ZONE'], project=os.environ['GCE_PROJECT_NAME'])\n",
    "tpu_grpc_url = tpu_cluster_resolver.get_master()\n",
    "\n",
    "batch_size = 128\n",
    "train_file = os.path.join(os.environ['GCS_DATA_PATH'], \"train.tfrecords\")\n",
    "train_steps = 1000\n",
    "\n",
    "eval_file = os.path.join(os.environ['GCS_DATA_PATH'], \"validation.tfrecords\")\n",
    "eval_steps = 100\n",
    "\n",
    "model_dir = os.path.join(os.environ['GCS_CKPT_PATH'], str(int(time.time()))) + \"/\"\n",
    "iterations = 50\n",
    "num_shards = 8\n",
    "\n",
    "os.environ['MNIST_MODEL_DIR'] = model_dir\n",
    "  \n",
    "tf.logging.set_verbosity(tf.logging.INFO)\n",
    "\n",
    "run_config = tpu_config.RunConfig(\n",
    "    master=tpu_grpc_url,\n",
    "    evaluation_master=tpu_grpc_url,\n",
    "    model_dir=model_dir,\n",
    "    session_config=tf.ConfigProto(\n",
    "        allow_soft_placement=True, log_device_placement=True),\n",
    "    tpu_config=tpu_config.TPUConfig(iterations, num_shards),)\n",
    "\n",
    "estimator = tpu_estimator.TPUEstimator(\n",
    "    model_fn=mnist_model_fn,\n",
    "    use_tpu=True,\n",
    "    train_batch_size=batch_size,\n",
    "    eval_batch_size=batch_size,\n",
    "    config=run_config)\n",
    "\n",
    "estimator.train(input_fn=mnist_get_input_fn(train_file),\n",
    "                max_steps=train_steps)\n",
    "\n",
    "if eval_steps:\n",
    "  estimator.evaluate(input_fn=mnist_get_input_fn(eval_file),\n",
    "                     steps=eval_steps)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize Graphs in TensorBoard\n",
    "\n",
    "You can visualize the model and training details in TensorBoard. To launch TensorBoard, simply pass the GCS path of the model directory to the `--logdir` flag. Make sure that port 6006 is open.\n",
    "\n",
    "To stop TensorBoard, click on Kernel > Interrupt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!echo Visit http://`curl -H \"Metadata-Flavor: Google\" http://metadata/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip 2> /dev/null`:6006 for TensorBoard\n",
    "!which pip3 && pip3 install html5lib==0.99999999  # workaround for TensorBoard dependency error in Python 3\n",
    "!tensorboard --logdir=$MNIST_MODEL_DIR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deleting your Cloud TPU\n",
    "\n",
    "To delete the Cloud TPU you have created, simply run the command below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!yes | gcloud alpha compute tpus delete $TPU_NAME"
   ]
  },
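  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, you can also remove the GCS buckets created earlier. Note that `gsutil rm -r` permanently deletes the buckets and everything in them, so double-check the bucket names before running this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional cleanup: permanently deletes both buckets and all of their contents\n",
    "!gsutil rm -r gs://$GCS_DATA_BUCKET gs://$GCS_CKPT_BUCKET"
   ]
  },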
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### That's All\n",
    "\n",
    "Congratulations! You have finished the tutorial. Along the way, you have:\n",
    "\n",
    "1. Created a new Cloud TPU using the `gcloud` command.\n",
    "2. Created two GCS buckets and added permissions for the Cloud TPU to read from/write to these buckets.\n",
    "3. Ran a simple computation to verify that your Cloud TPU works.\n",
    "4. Created TFRecord files suitable for Cloud TPU consumption from the MNIST dataset.\n",
    "5. Trained an MNIST image recognition model and evaluated the results.\n",
    "6. Visualized the results of training using TensorBoard.\n",
    "7. Deleted a Cloud TPU after you are done with training and evaluation.\n",
    "\n",
    "For more information about Cloud TPUs, you can take a look at the official [Cloud TPU documentation](https://cloud.google.com/tpu/docs/)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
